Search Options

Results per page
Sort
Preferred Languages
Advance

Results 1 - 10 of 11 for BatchMatMulV3 (0.34 sec)

  1. tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir

    // -----
    
    func.func @batchMatMulV3MatrixInt8(%arg0: tensor<4x5xi8>, %arg1: tensor<5x6xi8>) -> tensor<4x6xi32> {
      %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32>
      func.return %0 : tensor<4x6xi32>
    
      // CHECK-LABEL: batchMatMulV3MatrixInt8
      // CHECK: %0 = "tf.BatchMatMulV3"(%arg0, %arg1) : (tensor<4x5xi8>, tensor<5x6xi8>) -> tensor<4x6xi32>
      // CHECK: return %0 : tensor<4x6xi32>
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Dec 06 18:42:28 UTC 2023
    - 63.7K bytes
    - Viewed (0)
  2. tensorflow/cc/gradients/math_grad.cc

        grad_outputs->push_back(dy);
      } else {
        auto dx = BatchMatMulV3(scope, x0, x1, x_data_type,
                                BatchMatMulV3::AdjX(adj_x0).AdjY(adj_x1));
        grad_outputs->push_back(dx);
        auto dy = BatchMatMulV3(scope, y0, y1, y_data_type,
                                BatchMatMulV3::AdjX(adj_y0).AdjY(adj_y1));
        grad_outputs->push_back(dy);
      }
      return scope.status();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri Aug 25 18:20:20 UTC 2023
    - 50.7K bytes
    - Viewed (0)
  3. tensorflow/cc/gradients/math_grad_test.cc

        TestMatMulGradHelper<T>(
            /*is_x_batch=*/true, /*is_y_batch=*/true, t_x, t_y,
            [&](Output x, Output y) {
              return BatchMatMulV3(root_, x, y, DataTypeToEnum<T>::v(),
                                   BatchMatMulV3::AdjX(t_x).AdjY(t_y));
            });
      }
    
      template <typename T>
      void TestMatMulGradHelper(const bool is_x_batch, const bool is_y_batch,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri Aug 25 18:20:20 UTC 2023
    - 36K bytes
    - Viewed (0)
  4. tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc

      if (lhs_type.getElementType().isInteger(8) &&
          rhs_type.getElementType().isInteger(8)) {
        return rewriter.notifyMatchFailure(op,
                                           "skip unrolling for int8 BatchMatMulV3");
      }
    
      auto element_type = lhs_type.getElementType();
    
      if (element_type != rhs_type.getElementType()) {
        // The element type of LHS must be the same with element type of RHS
        return failure();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 11.6K bytes
    - Viewed (0)
  5. tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc

    }
    
    INSTANTIATE_TEST_SUITE_P(
        BatchMatMulTest, BatchMatMulTest,
        ::testing::ValuesIn<MatMulTestCase>({
            {"BatchMatMul"},
            {"BatchMatMulV2"},
            {"BatchMatMulV3"},
        }),
        [](const ::testing::TestParamInfo<BatchMatMulTest::ParamType>& info) {
          return info.param.mat_mul_method;
        });
    
    TEST(LegalizeTFTest, DumpsProducedHLO) {
      Env* env = Env::Default();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Jun 13 23:59:33 UTC 2024
    - 16.1K bytes
    - Viewed (0)
  6. tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir

    // CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32>
    // CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_3]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
    // CHECK:           %[[VAL_5:.*]] = arith.constant dense<1> : tensor<1xi64>
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed May 29 07:26:59 UTC 2024
    - 340.2K bytes
    - Viewed (0)
  7. tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc

    // operands are properly supported in declarative rewrite rule specification.
    
    DECL_CONVERT_OP(Assert);
    DECL_CONVERT_OP(ConcatV2);
    DECL_CONVERT_OP(BatchMatMul);
    DECL_CONVERT_OP(BatchMatMulV2);
    DECL_CONVERT_OP(BatchMatMulV3);
    DECL_CONVERT_OP(MatMul);
    DECL_CONVERT_OP(MatrixDiagV2);
    DECL_CONVERT_OP(MatrixDiagV3);
    DECL_CONVERT_OP(Pack);
    DECL_CONVERT_OP(Split);
    DECL_CONVERT_OP(SplitV);
    DECL_CONVERT_OP(Unpack);
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Mon May 20 20:06:54 UTC 2024
    - 45.2K bytes
    - Viewed (0)
  8. tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir

    }
    
    func.func @matmul_batchv3(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> {
      %0 = "tf.BatchMatMulV3"(%arg0, %arg1) {Ta = "tfdtype$DT_FLOAT", Tb = "tfdtype$DT_FLOAT",device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
    (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32>
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Jun 05 01:54:33 UTC 2024
    - 153.4K bytes
    - Viewed (0)
  9. tensorflow/compiler/jit/mark_for_compilation_pass.cc

    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 12:19:41 UTC 2024
    - 85.3K bytes
    - Viewed (0)
  10. tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc

      // Set benefit of this pattern to zero to prefer the fallback pattern when
      // available and applicable. That pattern avoids broadcast on operands and is
      // therefore faster.
      //
      // Native legalization for BatchMatMulV3 needs to be added as well.
      explicit ConvertBatchMatMulV2Op(MLIRContext *context)
          : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
    
      LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue Jun 11 20:00:43 UTC 2024
    - 291.8K bytes
    - Viewed (0)
Back to top